package edu.umd.clip.lm.model.data;

import java.util.Arrays;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.io.BufferedReader;
import java.io.IOException;
import java.nio.ByteBuffer;

import edu.umd.clip.lm.util.*;
import edu.umd.clip.lm.model.*;
import edu.umd.clip.lm.factors.*;

/**
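 * Base class for language-model training data containers. Every instance gets
 * a unique id; the context size (model order minus one) is stored as a
 * one-byte header on (de)serialization. The static {@code parse} methods
 * accumulate n-gram counts from tokenized input into {@code TrainingDataBlock}s.
 *
 * <p>A minimal usage sketch, assuming a concrete {@code WritableTrainingData}
 * implementation and a configured {@code Experiment}; the orders, thread
 * count, cache size, and file name here are illustrative only:
 * <pre>
 * WritableTrainingData data = ...;
 * BufferedReader input = new BufferedReader(new FileReader("train.txt"));
 * AbstractTrainingData.parse(data, 3, 3, input, 4, 1000000);
 * </pre>
 *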
 * @author Denis Filimonov <den@cs.umd.edu>
 *
 */
public abstract class AbstractTrainingData {
	private byte contextSize = -1;
	protected static AtomicInteger idSource = new AtomicInteger();
	private final int id;
	protected static final int HEADER_SIZE = 1;

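	/** Returns the next unique id from the shared counter; for use by subclasses. */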
	protected static int getNextId() {
		return idSource.incrementAndGet();
	}

	public int getId() {
		return id;
	}
	
	public AbstractTrainingData() {
		id = getNextId();
	}

	public byte getContextSize() {
		return contextSize;
	}

	public void setContextSize(byte contextSize) {
		this.contextSize = contextSize;
	}

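	/** Reads the one-byte header, restoring the context size, from {@code buffer}. */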
	protected void readHeader(ByteBuffer buffer) {
		assert(buffer.remaining() >= HEADER_SIZE);
		contextSize = buffer.get();
	}

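	/** Writes the one-byte header, containing the context size, to {@code buffer}. */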
	protected void writeHeader(ByteBuffer buffer) {
		assert(buffer.remaining() >= HEADER_SIZE);
		buffer.put(contextSize);
	}
	
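	// The parse(...) overloads below are convenience wrappers: they derive the
	// orders from a LanguageModel and/or default mergeLines to false, delegating
	// to the full implementation at the bottom of this class.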
	public static void parse(final WritableTrainingData trainingData, LanguageModel lm, BufferedReader input, int jobs, int maxCacheSize) throws IOException {
		parse(trainingData, lm.getOvertOrder(), lm.getHiddenOrder(), input, jobs, maxCacheSize, false);
	}
	
	public static void parse(final WritableTrainingData trainingData, LanguageModel lm, BufferedReader input, int jobs, int maxCacheSize, boolean mergeLines) throws IOException {
		parse(trainingData, lm.getOvertOrder(), lm.getHiddenOrder(), input, jobs, maxCacheSize, mergeLines);
	}
	
	public static void parse(final WritableTrainingData trainingData, int overtOrder, 
			int hiddenOrder, BufferedReader input, int jobs, int maxCacheSize) throws IOException 
	{
		parse(trainingData, overtOrder, hiddenOrder, input, jobs, maxCacheSize,	false);
	}

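	/**
	 * Parses tokenized sentences from {@code input} on {@code jobs} threads and
	 * accumulates (context, future) n-gram counts into {@code trainingData}.
	 * Counts are aggregated in an LRU cache holding at most {@code maxCacheSize}
	 * contexts; evicted entries are packed into {@link TrainingDataBlock}s, and
	 * the cache is drained once all input has been consumed.
	 *
	 * @param trainingData destination for the accumulated count blocks
	 * @param overtOrder   order of the overt-factor part of the model
	 * @param hiddenOrder  order of the hidden-factor part of the model
	 * @param input        tokenized training text
	 * @param jobs         number of parallel parser threads
	 * @param maxCacheSize maximum number of contexts cached in memory
	 * @param mergeLines   passed through to {@link ParallelInputParser}
	 */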
	@SuppressWarnings("serial")
	public static void parse(final WritableTrainingData trainingData, int overtOrder, 
			int hiddenOrder, BufferedReader input, int jobs, int maxCacheSize, 
			final boolean mergeLines) throws IOException 
	{
		final int order = Math.max(overtOrder, hiddenOrder);
		final int order_m1 = order - 1;
		
		trainingData.setContextSize((byte) order_m1);
		
		final int numNullOvertFactors = Math.min(order - overtOrder, order_m1);
		final int numNullHiddenFactors = Math.min(order - hiddenOrder, order_m1);
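		// A one-element array so the anonymous LRU eviction callback below can
		// replace the current block (locals captured by inner classes must be final).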
		final TrainingDataBlock[] block = new TrainingDataBlock[1];
		block[0] = new TrainingDataBlock();
		
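		// LRU cache of partial counts, context -> (future tuple -> count). When
		// it grows past maxCacheSize, the eldest entry is packed into a
		// ContextFuturesPair and appended to the current block; full blocks are
		// handed off to trainingData.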
		final LRU<Context, Long2IntMap> tmpData = 
			new LRU<Context, Long2IntMap>(maxCacheSize) {
				@Override
				protected boolean removeEldestEntry(
						Entry<Context, Long2IntMap> eldest) {
					if (super.removeEldestEntry(eldest)) {
						// TODO: use an intermediate queue
						Long2IntMap counts = eldest.getValue();
						
						ContextFuturesPair pair = new ContextFuturesPair(eldest.getKey(), TupleCountPair.fromMap(counts));
						if (!block[0].add(pair)) {
							try {
								trainingData.add(block[0]);
								block[0] = new TrainingDataBlock();
								block[0].add(pair);
							} catch (IOException e) {
								// removeEldestEntry cannot declare IOException; rethrow
								// unchecked rather than silently dropping counts.
								throw new RuntimeException(e);
							}
						}
						return true;
					}
					return false;
				}
		};

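		// Guards tmpData: the LRU (which appears to extend LinkedHashMap, given
		// removeEldestEntry) is not thread-safe, and the parser callback below
		// runs concurrently on up to 'jobs' threads.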
		final Lock tmpDataLock = new ReentrantLock();
		
		final Experiment exp = Experiment.getInstance();
		FactorTupleDescription desc = exp.getTupleDescription();
		FLMInputParser parser = new FLMInputParser(desc);
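		// Pad each sentence with order-1 start tuples and one end tuple so that
		// every event, including the first word, has a full-length context.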
		{
			long[] startTuples = new long[order_m1];
			Arrays.fill(startTuples, desc.createStartTuple());
			parser.setStartTuples(startTuples);
			long[] endTuples = new long[1];
			endTuples[0] = desc.createEndTuple();
			parser.setEndTuples(endTuples);
		}

		final long overtMask = desc.getOvertFactorsMask();
		final long hiddenMask = desc.getHiddenFactorsMask();
		
		final AtomicLong totalCount = new AtomicLong();
		final AtomicLong totalSentences = new AtomicLong();
		
		ParallelInputParser.Callback parserCallback = new ParallelInputParser.Callback() {
			public void process(long[] sentence) {
				// Validate every tuple up front; a per-call FactorTuple avoids
				// sharing mutable state between the parser threads.
				FactorTuple dummyTuple = new FactorTuple();
				for (long tuple : sentence) {
					dummyTuple.setBits(tuple);
					if ((tuple & hiddenMask) == 0 || (tuple & overtMask) == 0
							|| exp.getHiddenPrefix(dummyTuple) == null) {
						System.err.printf("skipping malformed input sentence%n");
						return;
					}
				}
				int sentEvents = 0;
				
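				// Slide a window of order-1 context tuples over the sentence;
				// the tuple immediately after the window is the predicted event.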
				for(int i=0; i<sentence.length-order_m1; ++i) {
					long[] tuples = new long[order_m1];
					for(int j=0; j<order_m1; ++j) {
						tuples[j] = sentence[i+j];
					}
					
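					// Truncate histories for the shorter sub-model: null out the
					// overt factors of the oldest (order - overtOrder) positions
					// and the hidden factors of the oldest (order - hiddenOrder).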
					for(int j=0; j<numNullOvertFactors; ++j) {
						tuples[j] &= FactorTuple.getHiddenMask();
					}
					for(int j=0; j<numNullHiddenFactors; ++j) {
						tuples[j] &= FactorTuple.getOvertMask();
					}

					Context ctx = new Context(tuples);
					
					tmpDataLock.lock();
					try {
						Long2IntMap futures = tmpData.get(ctx);
						if (futures == null) {
							futures = new Long2IntMap(2);
							tmpData.put(ctx, futures);
						}
						futures.addAndGet(sentence[i + order_m1], 1);
					} finally {
						// Always release the lock, even if flushing a full block in
						// the LRU eviction callback throws.
						tmpDataLock.unlock();
					}
					++sentEvents;
				}
				totalCount.addAndGet(sentEvents);
				totalSentences.incrementAndGet();
			}
		};
		
		ParallelInputParser inputParser = new ParallelInputParser(parser, parserCallback, jobs, mergeLines);
		inputParser.parse(input);

		//System.err.printf("Total sentences: %d, totalCount: %d\n", totalSentences.longValue(), totalCount.longValue());
		
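		// Drain the counts still cached once all input has been consumed.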
		for(Map.Entry<Context, Long2IntMap> entry : tmpData.entrySet()) {
			Long2IntMap counts = entry.getValue();

			ContextFuturesPair pair = new ContextFuturesPair(entry.getKey(), TupleCountPair.fromMap(counts));
			if (!block[0].add(pair)) {
				trainingData.add(block[0]);
				block[0] = new TrainingDataBlock();
				block[0].add(pair);
			}
		}
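		// Flush the last, partially filled block, if any.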
		if (block[0].hasData()) {
			trainingData.add(block[0]);
		}
	}
}
